import matplotlib.pyplot as plt # type: ignore
import numpy as np # type: ignore
import pandas as pd # type: ignore

def _read_reference_pressure(model, freqi):
    """Return the reference pressure used to offset the ``model`` colormap.

    Reads the first row of the ``ta.p_t`` real/imaginary columns at ``freqi``
    from ``comsol_raw/p_above_<model>...csv`` and ``comsol_raw/p_below_<model>...csv``
    and averages the two. The ``'box'`` model reads the plain files; any other
    model reads the ``_st_adjusted`` variants.

    Args:
        model: model name embedded in the file names (``'box'`` or ``'slice'``).
        freqi: frequency (Hz) selecting the ``@ freq=`` column.

    Returns:
        Tuple ``(re, im)``: averaged real and imaginary parts of the reference.
    """
    suffix = '.csv' if model == 'box' else '_st_adjusted.csv'
    re_col = 'real(ta.p_t) (Pa) @ freq=' + str(freqi)
    im_col = 'imag(ta.p_t) (Pa) @ freq=' + str(freqi)

    # skiprows=8: the column header is on line 9 of these exports.
    above_df = pd.read_csv('comsol_raw/p_above_' + model + suffix, skiprows=8, index_col=[0, 1, 2])
    below_df = pd.read_csv('comsol_raw/p_below_' + model + suffix, skiprows=8, index_col=[0, 1, 2])

    p_re = (above_df.loc[:, re_col].iloc[0] + below_df.loc[:, re_col].iloc[0]) / 2
    p_im = (above_df.loc[:, im_col].iloc[0] + below_df.loc[:, im_col].iloc[0]) / 2
    return p_re, p_im


def _read_pressure_grid(model, freqi):
    """Load the ``model`` pressure colormap export at ``freqi``.

    Args:
        model: model name embedded in the file name.
        freqi: frequency (Hz) selecting the ``@ freq=`` columns.

    Returns:
        Tuple ``(x, y, z_re, z_im)``: sorted unique values of the first and
        second index columns, and the real/imaginary pressure with one entry
        per CSV row, in file order.

    Raises:
        ValueError: if the samples do not form a square ``n x n`` grid, which
            the ``(n, n)`` reshape in ``plot_colormaps`` requires.
    """
    df = pd.read_csv('comsol_raw/' + model + '_colormap_pressure.csv', skiprows=8, index_col=[0, 1])
    x = np.unique(df.index.get_level_values(0))
    y = np.unique(df.index.get_level_values(1))
    z_re = np.asarray(df.loc[:, 'real(p) (Pa) @ freq=' + str(freqi)])
    z_im = np.asarray(df.loc[:, 'imag(p) (Pa) @ freq=' + str(freqi)])

    if len(x) != len(y) or z_re.size != len(x) * len(y):
        raise ValueError(
            f'{model} colormap is not a square grid: {len(x)} x-values, '
            f'{len(y)} y-values, {z_re.size} samples'
        )
    return x, y, z_re, z_im


def _wrap_phase(z_phase):
    """Reference a phase field (in cycles) to its first non-NaN sample and wrap it.

    The phase of the first non-NaN entry is subtracted from every entry. The
    result is then shifted by whole cycles into ``(-0.9, 0.1]``. NaN entries
    stay NaN.

    Args:
        z_phase: 1-D array of phases in cycles; may contain NaN.

    Returns:
        Tuple ``(wrapped, k, phase_shift)``: the wrapped array, the index of
        the reference sample, and the phase that was subtracted.

    Raises:
        ValueError: if every entry is NaN, so no reference sample exists.
    """
    valid = np.flatnonzero(~np.isnan(z_phase))
    if valid.size == 0:
        raise ValueError('phase field is entirely NaN; no reference sample to align to')
    k = int(valid[0])
    phase_shift = z_phase[k]

    shifted = z_phase - phase_shift
    # Equivalent (up to float rounding) to "add 1 while < -0.2, then subtract 1
    # while > 0.1": pick the value congruent to `shifted` modulo one cycle that
    # lies in (-0.9, 0.1]. Vectorised instead of a per-element Python loop.
    wrapped = shifted - np.ceil(shifted - 0.1)
    return wrapped, k, phase_shift


def _style_axis(ax):
    """Apply the shared limits, hidden top/right spines and equal aspect to a panel."""
    ax.set_xlim(-0.125, 0.125)
    ax.set_ylim(-0.1, 0.125)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_aspect(1)


def plot_colormaps(freqi=6000):
    """Plot pressure magnitude and phase colormaps for the 'box' and 'slice' models.

    For each model, the complex pressure exported in
    ``comsol_raw/<model>_colormap_pressure.csv`` is offset by the averaged
    above/below reference pressure (``_read_reference_pressure``). It is drawn
    in a 2x2 figure:

    * left column: 'box'; right column: 'slice';
    * top row: ``|p|`` with colour range 0-16;
    * bottom row: phase in cycles relative to the first non-NaN sample,
      wrapped into (-0.9, 0.1], with colour range -1 to 1.

    The figure is saved to ``manufigs/colormap_pressure_<freqi>Hz.png``, shown,
    and then closed. Diagnostic values are printed along the way.

    Args:
        freqi: frequency (Hz) selecting the ``@ freq=`` columns of the exports.

    Raises:
        ValueError: if a colormap export is not a square grid, or its phase
            field is entirely NaN.
    """
    fig, axarr = plt.subplots(2, 2, sharex='col', figsize=[7, 6])
    fig.subplots_adjust(right=0.85)  # leave room for the colorbars on the right

    for i, model in enumerate(['box', 'slice']):
        p_adjust_re, p_adjust_im = _read_reference_pressure(model, freqi)
        print(i, model)
        print(p_adjust_re, p_adjust_im)

        ###### mag ######
        print('mag')
        ax = axarr[0, i]

        x, y, z_re, z_im = _read_pressure_grid(model, freqi)
        n = len(x)  # grid is n x n (checked by _read_pressure_grid)
        X, Y = np.meshgrid(x, y)

        # Pressure relative to the reference; shared by the magnitude and phase plots.
        z_adjust = z_re - p_adjust_re + 1j * (z_im - p_adjust_im)

        z_mag = np.abs(z_adjust)
        print(np.nanmax(z_mag), np.nanmin(z_mag))

        # NOTE(review): the second CSV coordinate is drawn on the horizontal
        # axis (Y passed first), and the C-order (n, n) reshape assumes the CSV
        # row order matches; both are kept from the original. Confirm against
        # the export layout.
        im = ax.pcolormesh(Y, X, np.reshape(z_mag, (n, n)), cmap='jet', vmin=0, vmax=16)
        _style_axis(ax)
        if i == 1:
            cbar_ax = fig.add_axes([0.9, 0.53, 0.03, 0.35])
            cbar = fig.colorbar(im, cax=cbar_ax)
            cbar.ax.set_yticks([0, 2, 4, 6, 8, 10, 12, 14, 16])

        ###### phase ######
        print('phase')
        ax = axarr[1, i]

        z_phase = np.angle(z_adjust) / 2 / np.pi  # phase in cycles, within [-0.5, 0.5]
        print(np.nanmax(z_phase), np.nanmin(z_phase))

        z_phase, k, phase_shift = _wrap_phase(z_phase)
        print(k, phase_shift)

        im = ax.pcolormesh(Y, X, np.reshape(z_phase, (n, n)), cmap='jet', vmin=-1, vmax=1)
        _style_axis(ax)
        if i == 1:
            cbar_ax = fig.add_axes([0.9, 0.11, 0.03, 0.35])
            cbar = fig.colorbar(im, cax=cbar_ax)
            cbar.ax.set_yticks([-1, -0.5, 0, 0.5, 1])
            cbar.ax.set_yticklabels(['-1', '', '0', '', '1'])

    # Both columns use identical y limits, so hide the right column's y axis.
    for ax in axarr[:, 1]:
        ax.spines['left'].set_visible(False)
        ax.set_yticks([])

    fig.savefig('manufigs/colormap_pressure_' + str(freqi) + 'Hz.png')
    plt.show()
    # Release the figure so repeated calls (e.g. over several frequencies with a
    # non-interactive backend) do not accumulate open figures.
    plt.close(fig)

# Frequencies (Hz) to render; each call produces one saved and shown figure.
for freqi in (4000,):
    print(freqi)
    plot_colormaps(freqi=freqi)
